{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import segmentation_models_pytorch as smp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save to local directory and load back"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading weights from local directory\n"
     ]
    }
   ],
   "source": [
    "model = smp.Unet()\n",
    "\n",
    "# save the model\n",
    "model.save_pretrained(\"saved-model-dir/unet/\")\n",
    "\n",
    "# load the model\n",
    "restored_model = smp.from_pretrained(\"saved-model-dir/unet/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save model with additional metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = smp.Unet()\n",
    "\n",
    "# save the model\n",
    "model.save_pretrained(\n",
    "    \"saved-model-dir/unet-with-metadata/\",\n",
    "    # additional information to be saved with the model\n",
    "    # only \"dataset\" and \"metrics\" are supported\n",
    "    dataset=\"PASCAL VOC\",  # only string name is supported\n",
    "    metrics={  # should be a dictionary with metric name as key and metric value as value\n",
    "        \"mIoU\": 0.95,\n",
    "        \"accuracy\": 0.96,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---\n",
      "library_name: segmentation-models-pytorch\n",
      "license: mit\n",
      "pipeline_tag: image-segmentation\n",
      "tags:\n",
      "- model_hub_mixin\n",
      "- pytorch_model_hub_mixin\n",
      "- segmentation-models-pytorch\n",
      "- semantic-segmentation\n",
      "- pytorch\n",
      "languages:\n",
      "- python\n",
      "---\n",
      "# Unet Model Card\n",
      "\n",
      "Table of Contents:\n",
      "- [Load trained model](#load-trained-model)\n",
      "- [Model init parameters](#model-init-parameters)\n",
      "- [Model metrics](#model-metrics)\n",
      "- [Dataset](#dataset)\n",
      "\n",
      "## Load trained model\n",
      "```python\n",
      "import segmentation_models_pytorch as smp\n",
      "\n",
      "model = smp.from_pretrained(\"<save-directory-or-this-repo>\")\n",
      "```\n",
      "\n",
      "## Model init parameters\n",
      "```python\n",
      "model_init_params = {\n",
      "    \"encoder_name\": \"resnet34\",\n",
      "    \"encoder_depth\": 5,\n",
      "    \"encoder_weights\": \"imagenet\",\n",
      "    \"decoder_use_batchnorm\": True,\n",
      "    \"decoder_channels\": (256, 128, 64, 32, 16),\n",
      "    \"decoder_attention_type\": None,\n",
      "    \"in_channels\": 3,\n",
      "    \"classes\": 1,\n",
      "    \"activation\": None,\n",
      "    \"aux_params\": None\n",
      "}\n",
      "```\n",
      "\n",
      "## Model metrics\n",
      "```json\n",
      "{\n",
      "    \"mIoU\": 0.95,\n",
      "    \"accuracy\": 0.96\n",
      "}\n",
      "```\n",
      "\n",
      "## Dataset\n",
      "Dataset name: PASCAL VOC\n",
      "\n",
      "## More Information\n",
      "- Library: https://github.com/qubvel/segmentation_models.pytorch\n",
      "- Docs: https://smp.readthedocs.io/en/latest/\n",
      "\n",
      "This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin)"
     ]
    }
   ],
   "source": [
    "!cat \"saved-model-dir/unet-with-metadata/README.md\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Share model with HF Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d6fe9d868c24175aa5f23a2893a2c21",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "# You only need to run this once on the machine,\n",
    "# the token will be stored for later use\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f4f5e4973e44f9a857e89d9ac707b53",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/97.8M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', commit_message='Push model using huggingface_hub.', commit_description='', oid='4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = smp.Unet()\n",
    "\n",
    "# save the model and share it on the HF Hub (https://huggingface.co/models)\n",
    "model.save_pretrained(\n",
    "    \"qubvel-hf/unet-with-metadata/\",\n",
    "    push_to_hub=True,  # <----------    push the model to the hub\n",
    "    private=False,  # <----------    make the model private or or public\n",
    "    dataset=\"PASCAL VOC\",\n",
    "    metrics={\"mIoU\": 0.95, \"accuracy\": 0.96},\n",
    ")\n",
    "\n",
    "# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save model with preprocessing (using albumentations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U albumentations numpy==1.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import albumentations as A\n",
    "import segmentation_models_pytorch as smp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define a preprocessing transform for image that would be used during inference\n",
    "preprocessing_transform = A.Compose([A.Resize(256, 256), A.Normalize()])\n",
    "\n",
    "model = smp.Unet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1aa3f4db4cd2489baeac3b844977d5a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/97.8M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-transform/commit/680dad16431fa6efbb25832d33a24056bdf7dc1a', commit_message='Push transform using huggingface_hub.', commit_description='', oid='680dad16431fa6efbb25832d33a24056bdf7dc1a', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "directory_or_repo_on_the_hub = \"qubvel-hf/unet-with-transform\"\n",
    "\n",
    "# save the model\n",
    "model.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)\n",
    "\n",
    "# save transform\n",
    "preprocessing_transform.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's restore model and preprocessing transform for inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading weights from local directory\n",
      "Compose([\n",
      "  Resize(p=1.0, height=256, width=256, interpolation=1),\n",
      "  Normalize(p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, normalization='standard'),\n",
      "], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)\n"
     ]
    }
   ],
   "source": [
    "restored_model = smp.from_pretrained(directory_or_repo_on_the_hub)\n",
    "restored_transform = A.Compose.from_pretrained(directory_or_repo_on_the_hub)\n",
    "\n",
    "print(restored_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compose([\n",
      "  HorizontalFlip(p=0.5),\n",
      "  RandomBrightnessContrast(p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), brightness_by_max=True),\n",
      "  ShiftScaleRotate(p=0.5, shift_limit_x=(-0.0625, 0.0625), shift_limit_y=(-0.0625, 0.0625), scale_limit=(-0.09999999999999998, 0.10000000000000009), rotate_limit=(-45, 45), interpolation=1, border_mode=4, value=0.0, mask_value=0.0, rotate_method='largest_box'),\n",
      "], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)\n"
     ]
    }
   ],
   "source": [
    "# You can also save training augmentations to the Hub too (and load it back)!\n",
    "#! Just make sure to provide key=\"train\" when saving and loading the augmentations.\n",
    "\n",
    "train_augmentations = A.Compose(\n",
    "    [\n",
    "        A.HorizontalFlip(p=0.5),\n",
    "        A.RandomBrightnessContrast(p=0.2),\n",
    "        A.ShiftScaleRotate(p=0.5),\n",
    "    ]\n",
    ")\n",
    "\n",
    "train_augmentations.save_pretrained(\n",
    "    directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True\n",
    ")\n",
    "\n",
    "restored_train_augmentations = A.Compose.from_pretrained(\n",
    "    directory_or_repo_on_the_hub, key=\"train\"\n",
    ")\n",
    "print(restored_train_augmentations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See saved model and `albumentations` configs on the hub: https://huggingface.co/qubvel-hf/unet-with-transform/tree/main"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
